RSA OTP [crypto]

RSA OTP

RSA is kinda bad but I strengthened it with the unbreakable one time pad!

Solution

from Crypto.Util.number import bytes_to_long
from Crypto.Random.random import getrandbits
from Crypto.PublicKey import RSA
import math
import pwn

# Challenge parameters: RSA modulus n, public exponent e, and the
# ciphertext ef (used below as the encrypted flag).
n = 136018504103450744973226909842302068548152091075992057924542109508619184755376768234431340139221594830546350990111376831021784447802637892581966979028826938086172778174904402131356050027973054268478615792292786398076726225353285978936466029682788745325588134172850614459269636474769858467022326624710771957129
e = 65537
ef = 17482644844951175640843255713372869422739097498066773957636359990466096121278949693816080016671592558403643716793132479255285512907247513385850323834210899918531077167485767118313722022095603863840851451191536627814100144146010392752308431038754246815068245448456643024387011488032896209253644172833489422733

# Open the connection to the challenge server, then read and discard the
# first message it sends before any query is made.
conn = pwn.remote(host='crypto.2020.chall.actf.co', port=20600)
conn.recv()

def get_len(x):
    """Query the server with x and measure its answer.

    x is sent as a decimal number terminated by a newline. The reply is
    decoded and split on newlines, and the length of its second line is
    returned. Callers treat this length as a bit count (presumably the
    server prints the decryption of x in binary -- confirm against the
    challenge server).
    """
    query = (str(x) + '\n').encode()
    conn.send(query)
    # NOTE(review): assumes a single recv() delivers the whole reply.
    reply_lines = conn.recv().decode().split('\n')
    second_line = reply_lines[1]
    return len(second_line)

# Collect n_samples constraints on the flag f.
#
# For a multiplier i, (ef * i^e) mod n is an RSA encryption of (i * f) mod n,
# because textbook RSA is multiplicatively homomorphic. The server reports the
# bit length l of that value, which gives 2^(l-1) <= i*f mod n <= 2^l - 1.
# Each entry of `finds` is (l, i, lb, ub).
#
# The first sample uses i == 1. After that, each round bisects over the
# multiplier range [2^k + 1, 2^(k+1)], comparing the reported length with the
# previous sample's, until start == end; that multiplier is then recorded.
n_samples = 100
c_samples = 0
finds = []
i = 1
next_2n = 1
start, end = 1, 1
while c_samples < n_samples:
    i = (start + end) // 2
    # Use the names actually defined above (ef, e, n). `eflag` and `key`
    # are never defined, so the old expression raised NameError.
    v = (ef * pow(i, e, n)) % n
    l = get_len(v)
    lb = 1 << (l - 1)
    ub = (1 << l) - 1

    if start == end:
        # The bisection has converged on a single multiplier: record its
        # constraint and move on to the next power-of-two range.
        finds.append((l, i, lb, ub))
        start = (1 << next_2n) + 1
        end = 1 << (next_2n + 1)
        next_2n += 1
        c_samples += 1
        print(f'{(c_samples/n_samples)*100.0}%')
        continue

    # When the interval has two elements, i == start. Jumping start to end
    # (or end to start) forces convergence, so the loop cannot stall.
    if l > finds[-1][0]:
        if end == i:
            end = start
        else:
            end = i
    else:
        if start == i:
            start = end
        else:
            start = i

# Show every collected constraint in the form "lb <= i*f <= ub".
for bits, mult, low, high in finds:
    print(f'{bits} bits: {low} <= {mult}f <= {high}')

def sat(x):
    """Check a candidate flag value x against every collected constraint.

    Constraints are checked in the order they appear in `finds`. For each
    one, (x * i) mod n must lie in [lb, ub]. The first violated constraint
    decides the result: -1 when the residue is below lb, 1 when it is above
    ub. Returns 0 when every constraint holds.
    """
    for _bits, mult, low, high in finds:
        residue = x * mult % n
        if not low <= residue <= high:
            return -1 if residue < low else 1
    return 0

# Bisect for the flag inside the range given by the first sample. That sample
# was taken with i == 1, so its [lb, ub] bounds f directly.
# NOTE(review): this assumes sat() changes sign monotonically across the
# range; the code does not check that.
start, end = finds[0][2], finds[0][3]
while start <= end:
    pos = (start + end) // 2
    print(pos)
    s = sat(pos)
    if s == -1:
        # pos is too small. Exclude it: with `start = pos`, the interval
        # stops shrinking once end == start + 1 and the loop never ends.
        start = pos + 1
    elif s == 1:
        # pos is too large; exclude it for the same reason.
        end = pos - 1
    else:
        print(f'Answer lies between {start} and {end}')
        break
    if start == end:
        print(f'Answer is {start}')
        break
else:
    # The interval became empty: no candidate satisfied every constraint.
    print('No value in the search range satisfies every constraint')